{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2599100f",
   "metadata": {},
   "outputs": [],
   "source": [
    "from google.colab import drive\n",
    "\n",
    "# Mount Google Drive at /content/drive.\n",
    "mount_point = '/content/drive'\n",
    "drive.mount(mount_point)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9487c7b9",
   "metadata": {},
   "outputs": [],
   "source": [
    "''' Import Libraries '''\n",
    "\n",
    "import os\n",
    "from tensorflow.keras.models import load_model\n",
    "import numpy as np\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7ac803be",
   "metadata": {},
   "outputs": [],
   "source": [
    "''' Load Model '''\n",
    "\n",
    "# The model file is looked up in a 'trained models' folder under the current working directory.\n",
    "curr_path = os.getcwd()\n",
    "model_file = curr_path + '/trained models/Five-classification.h5'\n",
    "model = load_model(model_file)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8efd15a5",
   "metadata": {},
   "outputs": [],
   "source": [
    "''' Load Data '''\n",
    "# Prefix prepended to the .npy file names below; replace the placeholder with the real data location.\n",
    "path = 'data path'\n",
    "\n",
    "# Number of output classes for each supported classification task.\n",
    "calssificatin_type = {\"binary\": 1, \"superclasses\": 5, \"subclasses\": 23}\n",
    "classification_name = \"superclasses\"\n",
    "no_of_classes = calssificatin_type[classification_name]\n",
    "\n",
    "# Number of leads for each supported lead configuration.\n",
    "lead_type = {\"lead-I\": 1, \"bipolar-limb\": 3, \"unipolar-limb\": 3, \"limb-leads\": 6, \"precordial-leads\": 6, \"all-lead\": 12}\n",
    "lead_name = \"all-lead\"\n",
    "no_of_leads = lead_type[lead_name]\n",
    "\n",
    "# NOTE(security): allow_pickle=True lets np.load unpickle arbitrary objects; only load files from a trusted source.\n",
    "# NOTE(review): the label file is named 'y_test_sub.npy' while classification_name is 'superclasses' - confirm this is the intended file.\n",
    "x_test = np.load(path + 'x_test.npy', allow_pickle=True)\n",
    "y_test = np.load(path + 'y_test_sub.npy', allow_pickle=True)\n",
    "\n",
    "# Swap the last two axes, then add a trailing singleton axis.\n",
    "x_test = x_test.transpose(0, 2, 1)\n",
    "\n",
    "# Take the sample count from the loaded array instead of hard-coding it, so any test-set size works.\n",
    "x_test = x_test.reshape(x_test.shape[0], no_of_leads, 1000, 1)\n",
    "\n",
    "print(\"x_test  :\", x_test.shape)\n",
    "print(\"y_test  :\", y_test.shape)\n",
    "print('Data loaded')\n",
    "\n",
    "\n",
    "from sklearn.preprocessing import MultiLabelBinarizer\n",
    "\n",
    "if classification_name != \"binary\":\n",
    "    # Only test data is loaded in this notebook, so only y_test is binarized\n",
    "    # (the former y_train lines referenced a variable that is never defined here).\n",
    "    # NOTE(review): the binarizer is fitted on the test labels alone - confirm the resulting\n",
    "    # column order matches the label order the model was trained with.\n",
    "    mlb = MultiLabelBinarizer()\n",
    "    mlb.fit(y_test)\n",
    "    y_test = mlb.transform(y_test)\n",
    "    print('Data proocessed')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7a39f7e7",
   "metadata": {},
   "outputs": [],
   "source": [
    "\"\"\"Evaluate the model\"\"\"\n",
    "\n",
    "from sklearn.metrics import precision_recall_curve, f1_score, roc_auc_score, accuracy_score, auc\n",
    "\n",
    "\n",
    "def sklearn_metrics(y_true, y_pred):\n",
    "    \"\"\"Print accuracy, macro ROC-AUC, mean AUPRC and micro F1 as percentages.\n",
    "\n",
    "    Args:\n",
    "        y_true: 2-D array of ground-truth labels, indexed as [sample, class].\n",
    "        y_pred: 2-D array of predicted scores with the same layout; scores >= 0.5\n",
    "            are treated as positive for the thresholded metrics.\n",
    "\n",
    "    Returns:\n",
    "        None. The four metrics are printed.\n",
    "    \"\"\"\n",
    "    # Threshold a copy so the raw scores stay available for the curve-based metrics.\n",
    "    y_bin = np.copy(y_pred)\n",
    "    y_bin[y_bin >= 0.5] = 1\n",
    "    y_bin[y_bin < 0.5] = 0\n",
    "\n",
    "    # Take the class count from the labels so the loop and the averaging below\n",
    "    # always agree (the mean was previously divided by a hard-coded 5).\n",
    "    n_classes = y_true.shape[1]\n",
    "\n",
    "    # Area under the precision-recall curve, accumulated per class.\n",
    "    auc_sum = 0\n",
    "    for i in range(n_classes):\n",
    "        precision, recall, _ = precision_recall_curve(y_true[:, i], y_pred[:, i])\n",
    "        auc_sum += auc(recall, precision)\n",
    "\n",
    "    print(\"Accuracy        : {:.2f}\".format(accuracy_score(y_true.flatten(), y_bin.flatten()) * 100))\n",
    "    print(\"Macro AUC score : {:.2f}\".format(roc_auc_score(y_true, y_pred, average='macro') * 100))\n",
    "    print('AUPRC           : {:.2f}'.format((auc_sum / n_classes) * 100))\n",
    "    print(\"Micro F1 score  : {:.2f}\".format(f1_score(y_true, y_bin, average='micro') * 100))\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7acb134d",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Predict on the test inputs with the loaded model.\n",
    "y_prob = model.predict(x_test)\n",
    "print(\"\\nTest\")\n",
    "# Pass the predictions computed above; the previous argument name was never defined.\n",
    "sklearn_metrics(y_test, y_prob)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.18"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
